/*
 * Copyright (C) 2016 Edward Raff
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package com.jstarcraft.ai.jsat.clustering;

import static java.lang.Math.max;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.exceptions.FailedToFitException;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.linear.VecPaired;
import com.jstarcraft.ai.jsat.linear.distancemetrics.DistanceMetric;
import com.jstarcraft.ai.jsat.linear.distancemetrics.EuclideanDistance;
import com.jstarcraft.ai.jsat.linear.vectorcollection.DefaultVectorCollection;
import com.jstarcraft.ai.jsat.linear.vectorcollection.VectorCollection;
import com.jstarcraft.ai.jsat.linear.vectorcollection.VectorCollectionUtils;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.utils.FibHeap;
import com.jstarcraft.ai.jsat.utils.Tuple3;
import com.jstarcraft.ai.jsat.utils.UnionFind;
import com.jstarcraft.core.utility.Integer2IntegerKeyValue;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;

/**
 * HDBSCAN is a density based clustering algorithm that is an improvement over
 * {@link DBSCAN}. Unlike its predecessor, HDBSCAN works with variable density
 * datasets and does not need a search radius to be specified. The original
 * paper presents HDBSCAN with two parameters {@link #setMinPoints(int)
 * m<sub>pts</sub>} and {@link #setMinClusterSize(int) m<sub>clSize</sub>}, but
 * recomends that they can be set to the same value and effectively behave as if
 * only one parameter exists. This implementation allows for setting both
 * independtly, but the single parameter constructors will use the same value
 * for both parameters. <br>
 * NOTE: The current implementation has O(N<sup>2</sup>) run time, though this
 * may be improved in the future with more advanced algorithms.<br>
 * <br>
 * See: Campello, R. J. G. B., Moulavi, D., & Sander, J. (2013). Density-Based
 * Clustering Based on Hierarchical Density Estimates. In J. Pei, V. Tseng, L.
 * Cao, H. Motoda, & G. Xu (Eds.), Advances in Knowledge Discovery and Data
 * Mining (pp. 160–172). Springer Berlin Heidelberg.
 * doi:10.1007/978-3-642-37456-2_14
 * 
 * @author Edward Raff
 */
public class HDBSCAN implements Clusterer, Parameterized {
    private DistanceMetric dm;
    /**
     * minimium number of points
     */
    private int m_pts;
    private int m_clSize;
    private VectorCollection<Vec> vc;

    /**
     * Creates a new HDBSCAN object using a threshold of 15 points to form a
     * cluster.
     */
    public HDBSCAN() {
        this(15);
    }

    /**
     * Creates a new HDBSCAN using the simplified form, where the only parameter is
     * a single value.
     *
     * @param m_pts the minimum number of points needed to form a cluster and the
     *              number of neighbors to consider
     */
    public HDBSCAN(int m_pts) {
        this(new EuclideanDistance(), m_pts);
    }

    /**
     * Creates a new HDBSCAN using the simplified form, where the only parameter is
     * a single value.
     *
     * @param dm    the distance metric to use for finding nearest neighbors
     * @param m_pts the minimum number of points needed to form a cluster and the
     *              number of neighbors to consider
     */
    public HDBSCAN(DistanceMetric dm, int m_pts) {
        this(dm, m_pts, m_pts, new DefaultVectorCollection<>());
    }

    /**
     * Creates a new HDBSCAN using the simplified form, where the only parameter is
     * a single value.
     *
     * @param dm    the distance metric to use for finding nearest neighbors
     * @param m_pts the minimum number of points needed to form a cluster and the
     *              number of neighbors to consider
     * @param vcf   the vector collection to use for accelerating nearest neighbor
     *              queries
     */
    public HDBSCAN(DistanceMetric dm, int m_pts, VectorCollection<Vec> vcf) {
        this(dm, m_pts, m_pts, vcf);
    }

    /**
     * Creates a new HDBSCAN using the full specification of the algorithm, where
     * two parameters may be altered. In the simplified version both parameters
     * always have the same value.
     *
     * @param dm       the distance metric to use for finding nearest neighbors
     * @param m_pts    the number of neighbors to consider, acts as a smoothing over
     *                 the density estimate
     * @param m_clSize the minimum number of data points needed to form a cluster
     * @param vc       the vector collection to use for accelerating nearest
     *                 neighbor queries
     */
    public HDBSCAN(DistanceMetric dm, int m_pts, int m_clSize, VectorCollection<Vec> vc) {
        this.dm = dm;
        this.m_pts = m_pts;
        this.m_clSize = m_clSize;
        this.vc = vc;
    }

    /**
     * Copy constructor
     * 
     * @param toCopy the object to copy
     */
    public HDBSCAN(HDBSCAN toCopy) {
        this.dm = dm.clone();
        this.m_pts = toCopy.m_pts;
        this.m_clSize = toCopy.m_clSize;
        this.vc = toCopy.vc.clone();
    }

    /**
     * 
     * @param m_clSize the minimum number of data points needed to form a cluster
     */
    public void setMinClusterSize(int m_clSize) {
        this.m_clSize = m_clSize;
    }

    /**
     * 
     * @return the minimum number of data points needed to form a cluster
     */
    public int getMinClusterSize() {
        return m_clSize;
    }

    /**
     * Sets the distance metric to use for determining closeness between data points
     * 
     * @param dm the distance metric to determine nearest neighbors with
     */
    public void setDistanceMetrics(DistanceMetric dm) {
        this.dm = dm;
    }

    /**
     * 
     * @return the distance metric to determine nearest neighbors with
     */
    public DistanceMetric getDistanceMetrics() {
        return dm;
    }

    /**
     * 
     * @param m_pts the number of neighbors to consider, acts as a smoothing over
     *              the density estimate
     */
    public void setMinPoints(int m_pts) {
        this.m_pts = m_pts;
    }

    /**
     * 
     * @return the number of neighbors to consider, acts as a smoothing over the
     *         density estimate
     */
    public int getMinPoints() {
        return m_pts;
    }

    @Override
    public HDBSCAN clone() {
        return new HDBSCAN(this);
    }

    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        if (designations == null)
            designations = new int[dataSet.size()];

        @SuppressWarnings("unchecked")
        final List<Vec> X = dataSet.getDataVectors();
        final int N = X.size();
        List<Double> cache = dm.getAccelerationCache(X, parallel);
        VectorCollection<Vec> X_vc = vc.clone();
        X_vc.build(parallel, X, dm);
        // 1. Compute the core distance w.r.t. m_pts for all data objects in X.
        /*
         * (Core Distance): The core distance of an object x_p ∈ X w.r.t. m_pts,
         * d_core(x_p), is the distance from x_p to its m_pts-nearest neighbor (incl.
         * x_p)
         */
        List<List<? extends VecPaired<Vec, Double>>> allNearestNeighbors = VectorCollectionUtils.allNearestNeighbors(X_vc, X, m_pts, parallel);
        double[] core = new double[N];
        for (int i = 0; i < N; i++)
            core[i] = allNearestNeighbors.get(i).get(m_pts - 1).getPair();

        // 2. Compute an MST of G_{m_pts}, the Mutual Reachability Graph.

        // prims algorithm from Wikipedia
        double[] C = new double[N];
        Arrays.fill(C, Double.MAX_VALUE);
        int[] E = new int[N];
        Arrays.fill(E, -1);// -1 "a special flag value indicating that there is no edge connecting v to
                           // earlier vertices"

        FibHeap<Integer> Q = new FibHeap<>();
        List<FibHeap.FibNode<Integer>> q_nodes = new ArrayList<>(N);
        for (int i = 0; i < N; i++)
            q_nodes.add(Q.insert(i, C[i]));
        Set<Integer> F = new HashSet<>();

        /**
         * First 2 indicate the edges, 3d value is the weight
         */
        List<Tuple3<Integer, Integer, Double>> mst_edges = new ArrayList<>(N * 2);

        while (Q.size() > 0) {
            // a. Find and remove a vertex v from Q having the minimum possible value of
            // C[v]
            FibHeap.FibNode<Integer> node = Q.removeMin();
            int v = node.getValue();
            q_nodes.set(v, null);
            // b. Add v to F and, if E[v] is not the special flag value, also add E[v] to F
            F.add(v);

            if (E[v] >= 0)
                mst_edges.add(new Tuple3<>(v, E[v], C[v]));

            /*
             * c. Loop over the edges vw connecting v to other vertices w. For each such
             * edge, if w still belongs to Q and vw has smaller weight than C[w]: Set C[w]
             * to the cost of edge vw Set E[w] to point to edge vw.
             */

            for (int w = 0; w < N; w++) {
                FibHeap.FibNode<Integer> w_node = q_nodes.get(w);
                if (w_node == null)// this node is already in F
                    continue;

                double mutual_reach_dist_vw = max(core[v], max(core[w], dm.dist(v, w, X, cache)));
                if (mutual_reach_dist_vw < C[w]) {
                    Q.decreaseKey(w_node, mutual_reach_dist_vw);
                    C[w] = mutual_reach_dist_vw;
                    E[w] = v;
                }

            }

        }

        // prim is done, we have the MST!

        /*
         * 3. Extend the MST to obtain MSText, by adding for each vertex a “self edge”
         * with the core distance of the corresponding object as weight
         */

        for (int i = 0; i < N; i++)
            mst_edges.add(new Tuple3<>(i, i, core[i]));

        // 4. Extract the HDBSCAN hierarchy as a dendrogram from MSText:

        List<UnionFind<Integer>> ufs = new ArrayList<>(N);
        for (int i = 0; i < N; i++)
            ufs.add(new UnionFind<>(i));
        // sort edges from smallest weight to largest
        PriorityQueue<Tuple3<Integer, Integer, Double>> edgeQ = new PriorityQueue<>(2 * N, (o1, o2) -> o1.getZ().compareTo(o2.getZ()));
        edgeQ.addAll(mst_edges);

        // everyone starts in their own cluster!
        List<IntList> currentGroups = new ArrayList<>();
        for (int i = 0; i < N; i++) {
            IntArrayList il = new IntArrayList(1);
            il.add(i);
            currentGroups.add(il);
        }

        int next_cluster_label = 0;
        /**
         * List of all the cluster options we have found
         */
        List<List<Integer>> cluster_options = new ArrayList<>();
        /**
         * Stores a list for each cluster. Each value in the sub list is the weight at
         * which that data point was added to the cluster
         */
        List<DoubleArrayList> entry_size = new ArrayList<>();
        DoubleArrayList birthSize = new DoubleArrayList();
        DoubleArrayList deathSize = new DoubleArrayList();
        List<Integer2IntegerKeyValue> children = new ArrayList<>();
        int[] map_to_cluster_label = new int[N];
        Arrays.fill(map_to_cluster_label, -1);

        while (!edgeQ.isEmpty()) {
            Tuple3<Integer, Integer, Double> edge = edgeQ.poll();
            double weight = edge.getZ();
            int from = edge.getX();
            int to = edge.getY();

            if (to == from)
                continue;

            UnionFind<Integer> union_from = ufs.get(from);
            UnionFind<Integer> union_to = ufs.get(to);

            int clust_A = union_from.find().getItem();
            int clust_B = union_to.find().getItem();

            UnionFind<Integer> clust_A_tmp = union_from.find();
            UnionFind<Integer> clust_B_tmp = union_to.find();

            union_from.union(union_to);

            int a_size = currentGroups.get(clust_A).size();
            int b_size = currentGroups.get(clust_B).size();
            int new_size = a_size + b_size;

            int mergedClust;
            int otherClust;
            if (union_from.find().getItem() == clust_A) {
                mergedClust = clust_A;
                otherClust = clust_B;
            } else// other way around
            {
                mergedClust = clust_B;
                otherClust = clust_A;
            }

            if (new_size >= m_clSize && a_size < m_clSize && b_size < m_clSize) {// birth of a new cluster!
                cluster_options.add(currentGroups.get(mergedClust));

                DoubleArrayList dl = new DoubleArrayList(new_size);
                for (int i = 0; i < new_size; i++)
                    dl.add(weight);
                entry_size.add(dl);

                children.add(null);// we have not children!
                birthSize.add(weight);
                deathSize.add(Double.MAX_VALUE);// we don't know yet
                map_to_cluster_label[mergedClust] = next_cluster_label;
                next_cluster_label++;
            } else if (new_size >= m_clSize && a_size >= m_clSize && b_size >= m_clSize) {// birth of a new cluster from the death of two others!
                                                                                          // record the weight that the other two died at
                deathSize.set(map_to_cluster_label[mergedClust], weight);
                deathSize.set(map_to_cluster_label[otherClust], weight);

                // replace with new object so that old references in cluster_options are not
                // altered further
                currentGroups.set(mergedClust, new IntArrayList(currentGroups.get(mergedClust)));

                cluster_options.add(currentGroups.get(mergedClust));
                DoubleArrayList dl = new DoubleArrayList(new_size);
                for (int i = 0; i < new_size; i++)
                    dl.add(weight);
                entry_size.add(dl);

                children.add(new Integer2IntegerKeyValue(map_to_cluster_label[mergedClust], map_to_cluster_label[otherClust]));
                birthSize.add(weight);
                deathSize.add(Double.MAX_VALUE);// we don't know yet
                map_to_cluster_label[mergedClust] = next_cluster_label;
                next_cluster_label++;
            } else if (new_size >= m_clSize) {// existing cluster has grown in size, so add the points and record their weight
                                              // for use later
                                              // index may change, so book keeping update
                if (map_to_cluster_label[mergedClust] == -1)// the people being added are the new owners
                {
                    // set to avoid index out of boudns bellow
                    int c = map_to_cluster_label[mergedClust] = map_to_cluster_label[otherClust];
                    // make sure we keep track of the correct list
                    cluster_options.set(c, currentGroups.get(mergedClust));
                    map_to_cluster_label[otherClust] = -1;
                }

                for (int indx : currentGroups.get(otherClust))
                    try {
                        entry_size.get(map_to_cluster_label[mergedClust]).add(weight);
                    } catch (IndexOutOfBoundsException ex) {
                        ex.printStackTrace();
                        throw new FailedToFitException(ex);
                    }

            }

            currentGroups.get(mergedClust).addAll(currentGroups.get(otherClust));
            currentGroups.set(otherClust, null);

        }

        // Remove the last "cluster" because its the dumb one of everything
        cluster_options.remove(cluster_options.size() - 1);
        entry_size.remove(entry_size.size() - 1);
        birthSize.removeDouble(birthSize.size() - 1);
        deathSize.removeDouble(deathSize.size() - 1);
        children.remove(children.size() - 1);

        /**
         * See equation (3) in paper
         */
        double[] S = new double[cluster_options.size()];
        for (int c = 0; c < S.length; c++) {
            double lambda_min = birthSize.getDouble(c);
            double lambda_max = deathSize.getDouble(c);
            double s = 0;
            for (double f_x : entry_size.get(c))
                s += Math.min(f_x, lambda_max) - lambda_min;
            S[c] = s;
        }

        boolean[] toKeep = new boolean[S.length];
        double[] S_hat = new double[cluster_options.size()];
        Arrays.fill(toKeep, true);
        Queue<Integer> notKeeping = new ArrayDeque<>();

        for (int i = 0; i < S.length; i++) {
            Integer2IntegerKeyValue child = children.get(i);
            if (child == null)// I'm a leaf!
            {
                // for all leaf nodes, set ˆS(C_h)= S(C_h)
                S_hat[i] = S[i];
                continue;
            }
            int il = child.getKey();
            int ir = child.getValue();
            // If S(C_i) < ˆS(C_il)+ ˆ S(C_ir ), set ˆS(C_i)= ˆS(C_il)+ ˆS(C_ir )and set δi
            // =0.
            if (S[i] < S_hat[il] + S_hat[ir]) {
                S_hat[i] = S_hat[il] + S_hat[ir];
                toKeep[i] = false;
            } else// Else: set ˆS(C_i)= S(C_i)and set δ(·) = 0 for all clusters in C_i’s subtrees.
            {
                S_hat[i] = S[i];
                // place children in q to process and set all sub children as not keeping
                notKeeping.add(il);
                notKeeping.add(ir);
                while (!notKeeping.isEmpty()) {
                    int c = notKeeping.poll();
                    toKeep[c] = false;
                    Integer2IntegerKeyValue c_children = children.get(c);
                    if (c_children == null)
                        continue;
                    notKeeping.add(c_children.getKey());
                    notKeeping.add(c_children.getValue());
                }
            }
        }

        // initially fill with -1 indicating it was noise
        Arrays.fill(designations, 0, N, -1);

        int clusters = 0;
        for (int c = 0; c < toKeep.length; c++)
            if (toKeep[c]) {
                for (int indx : cluster_options.get(c))
                    designations[indx] = clusters;
                clusters++;
            }

        return designations;
    }
}
